// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#include "melon/string.h"
#include <pollux/common/exception/exceptions.h>
#include <melon/container/f14_map.h>
#include <cstdint>
#include <cstring>
#include <ios>
#include <sstream>
#include <string>

namespace kumo {
    namespace pollux {
        namespace common {
            namespace compression {
                // Source-pointer adjustments used by lzoDecompress when a match lies fewer than
                // 8 bytes behind the output. Indexed by that distance, they step the source so
                // it ends up at least 8 bytes behind the output after the first 8 bytes are
                // emitted, letting the remainder of the match be copied 8 bytes at a time.
                static constexpr int32_t DEC_32_TABLE[8] = {4, 1, 2, 1, 4, 4, 4, 4};
                static constexpr int32_t DEC_64_TABLE[8] = {0, 0, 0, -1, 0, 1, 2, 3};

                // Byte widths of the fixed-size loads and stores performed by the decoder.
                static constexpr int32_t SIZE_OF_SHORT = 2;
                static constexpr int32_t SIZE_OF_INT = 4;
                static constexpr int32_t SIZE_OF_LONG = 8;

                // Renders `val` as lowercase hexadecimal with a "0x" prefix, e.g. 255 -> "0xff".
                static std::string toHex(uint64_t val) {
                    std::ostringstream digits;
                    digits << std::hex << val;
                    return "0x" + digits.str();
                }

                // Renders `val` in base 10, with a leading '-' for negative values.
                static std::string toString(int64_t val) {
                    return std::to_string(val);
                }

                namespace {
                    class MalformedInputException : public dwio::common::ParseError {
                    public:
                        explicit MalformedInputException(int64_t off)
                            : ParseError("MalformedInputException at " + toString(off)) {
                        }

                        MalformedInputException(int64_t off, const std::string &msg)
                            : ParseError("MalformedInputException " + msg + " at " + toString(off)) {
                        }

                        MalformedInputException(const MalformedInputException &other)
                            : ParseError(other.what()) {
                        }

                        MalformedInputException &operator=(const MalformedInputException &) = delete;

                        ~MalformedInputException() noexcept override = default;
                    };
                } // namespace

                namespace {
                    // Reads a 16-bit little-endian value byte by byte. LZO trailers are
                    // little-endian; assembling the bytes explicitly avoids an unaligned load and
                    // yields the same value on big-endian hosts.
                    inline uint32_t loadLittleEndian16(const char *p) {
                        return static_cast<uint32_t>(static_cast<uint8_t>(p[0])) |
                               (static_cast<uint32_t>(static_cast<uint8_t>(p[1])) << 8);
                    }

                    // Copies 4 bytes from `src` to `dst`. The word is loaded into a local before
                    // it is stored, so unaligned and overlapping ranges are well defined (unlike
                    // dereferencing a reinterpret_cast pointer or one overlapping memcpy).
                    inline void copyUnaligned32(char *dst, const char *src) {
                        uint32_t word;
                        std::memcpy(&word, src, sizeof(word));
                        std::memcpy(dst, &word, sizeof(word));
                    }

                    // 8-byte counterpart of copyUnaligned32.
                    inline void copyUnaligned64(char *dst, const char *src) {
                        uint64_t word;
                        std::memcpy(&word, src, sizeof(word));
                        std::memcpy(dst, &word, sizeof(word));
                    }

                    // Decodes the variable-length extension LZO uses when the length bits of a
                    // command are all zero: every 0x00 byte adds 255 and the first non-zero byte
                    // is added and ends the sequence. Returns `base` plus that sum and advances
                    // `input` past the bytes consumed. The 64-bit accumulator grows by at most
                    // 255 per input byte, so it cannot wrap for any input that fits in memory.
                    // Throws MalformedInputException if the input ends before a non-zero byte.
                    uint64_t readVariableLength(
                        const char *&input,
                        const char *inputAddress,
                        const char *inputLimit,
                        uint64_t base) {
                        uint64_t length = base;
                        while (true) {
                            if (input >= inputLimit) {
                                throw MalformedInputException(input - inputAddress);
                            }
                            const uint32_t nextByte = static_cast<uint8_t>(*input++);
                            if (nextByte != 0) {
                                return length + nextByte;
                            }
                            length += 0xff;
                        }
                    }

                    // Appends `length` bytes copied from `distance` bytes behind `output`. The
                    // source may overlap the bytes being written (distance < length), which
                    // replicates a repeating pattern. `position` is the input offset reported
                    // on error. Returns the output position just past the match.
                    // Throws MalformedInputException if the match starts before outputAddress
                    // or would extend past outputLimit.
                    char *copyMatch(
                        char *output,
                        const char *outputAddress,
                        char *outputLimit,
                        uint64_t distance,
                        uint64_t length,
                        int64_t position) {
                        // Compare sizes rather than forming `output - distance`, which would be
                        // undefined behavior if it pointed before the buffer.
                        if (distance == 0 ||
                            distance > static_cast<uint64_t>(output - outputAddress) ||
                            length > static_cast<uint64_t>(outputLimit - output)) {
                            throw MalformedInputException(position);
                        }
                        const char *match = output - distance;
                        char *const matchOutputLimit = output + length;

                        if (outputLimit - output < SIZE_OF_LONG) {
                            // Too close to the end of the buffer for 8-byte stores.
                            while (output < matchOutputLimit) {
                                *output++ = *match++;
                            }
                            return matchOutputLimit;
                        }

                        // From here on, 8-byte stores may run past matchOutputLimit but never past
                        // outputLimit; bytes beyond matchOutputLimit are scratch that later output
                        // overwrites or that lies beyond the size finally returned.
                        char *const fastOutputLimit = outputLimit - SIZE_OF_LONG;
                        if (distance < static_cast<uint64_t>(SIZE_OF_LONG)) {
                            // Short distance: emit the first 8 bytes with byte and 4-byte copies,
                            // stepping the source so it ends up at least 8 bytes behind output.
                            output[0] = match[0];
                            output[1] = match[1];
                            output[2] = match[2];
                            output[3] = match[3];
                            output += SIZE_OF_INT;
                            match += DEC_32_TABLE[distance];
                            copyUnaligned32(output, match);
                            output += SIZE_OF_INT;
                            match -= DEC_64_TABLE[distance];
                        } else {
                            copyUnaligned64(output, match);
                            match += SIZE_OF_LONG;
                            output += SIZE_OF_LONG;
                        }

                        if (matchOutputLimit >= fastOutputLimit) {
                            while (output < fastOutputLimit) {
                                copyUnaligned64(output, match);
                                match += SIZE_OF_LONG;
                                output += SIZE_OF_LONG;
                            }
                            while (output < matchOutputLimit) {
                                *output++ = *match++;
                            }
                        } else {
                            while (output < matchOutputLimit) {
                                copyUnaligned64(output, match);
                                match += SIZE_OF_LONG;
                                output += SIZE_OF_LONG;
                            }
                        }
                        return matchOutputLimit;
                    }

                    // Copies `length` literal bytes from input to output and advances both.
                    // Throws MalformedInputException if either buffer is too short.
                    void copyLiteral(
                        const char *&input,
                        const char *inputAddress,
                        const char *inputLimit,
                        char *&output,
                        char *outputLimit,
                        uint64_t length) {
                        const uint64_t inputRemaining = static_cast<uint64_t>(inputLimit - input);
                        const uint64_t outputRemaining = static_cast<uint64_t>(outputLimit - output);
                        if (length > inputRemaining || length > outputRemaining) {
                            throw MalformedInputException(input - inputAddress);
                        }

                        if (length + SIZE_OF_LONG <= inputRemaining &&
                            length + SIZE_OF_LONG <= outputRemaining) {
                            // Fast path: 8-byte chunks may read and write up to 7 bytes past the
                            // literal, which both buffers have room for.
                            const char *src = input;
                            char *dst = output;
                            char *const end = output + length;
                            do {
                                copyUnaligned64(dst, src);
                                src += SIZE_OF_LONG;
                                dst += SIZE_OF_LONG;
                            } while (dst < end);
                        } else if (length != 0) {
                            // Slow, exact copy near the end of either buffer.
                            std::memcpy(output, input, length);
                        }
                        input += length;
                        output += length;
                    }
                } // namespace

                /// Decompresses LZO-compressed data from [inputAddress, inputLimit) into
                /// [outputAddress, outputLimit). The input may hold several concatenated
                /// blocks; each must end with the command 0x11 followed by two zero bytes.
                ///
                /// @return number of bytes written to the output (0 for empty input).
                /// @throws MalformedInputException (a dwio::common::ParseError) if the input
                ///   is truncated, a block lacks its terminator, a match references data
                ///   before outputAddress, or the output would exceed outputLimit.
                uint64_t lzoDecompress(
                    const char *inputAddress,
                    const char *inputLimit,
                    char *outputAddress,
                    char *outputLimit) {
                    // nothing compresses to nothing
                    if (inputAddress == inputLimit) {
                        return 0;
                    }

                    const char *input = inputAddress;
                    char *output = outputAddress;

                    // LZO can concat two blocks together, so decode until the input is consumed.
                    while (input < inputLimit) {
                        bool firstCommand = true;
                        // Literal count emitted by the previous command; it decides how a command
                        // byte below 0x10 is interpreted.
                        uint64_t lastLiteralLength = 0;

                        while (true) {
                            if (input >= inputLimit) {
                                throw MalformedInputException(input - inputAddress);
                            }
                            const uint32_t command = static_cast<uint8_t>(*input++);
                            if (command == 0x11) {
                                // End of block; its two trailer bytes are validated below.
                                break;
                            }

                            // Bit patterns: L literal length, P match distance, M match length.
                            uint64_t matchLength = 0;
                            uint64_t matchDistance = 0;
                            uint64_t literalLength = 0;

                            if ((command & 0xf0) == 0) {
                                if (lastLiteralLength == 0) {
                                    // 0b0000_LLLL (0bLLLL_LLLL)*: literal run of 4 or more bytes
                                    literalLength = command & 0xf;
                                    if (literalLength == 0) {
                                        literalLength =
                                                readVariableLength(input, inputAddress, inputLimit, 0xf);
                                    }
                                    literalLength += 3;
                                } else {
                                    // 0b0000_PPLL 0bPPPP_PPPP: short match
                                    if (input >= inputLimit) {
                                        throw MalformedInputException(input - inputAddress);
                                    }
                                    uint64_t offset = (command & 0xc) >> 2;
                                    offset |= static_cast<uint64_t>(static_cast<uint8_t>(*input++)) << 2;
                                    if (lastLiteralLength <= 3) {
                                        // after 1..3 literals: 2 bytes, distance [1..1024]
                                        matchLength = 2;
                                    } else {
                                        // after a literal run of 4+ bytes: 3 bytes, distance [2049..3072]
                                        matchLength = 3;
                                        offset |= 0x800;
                                    }
                                    // lzo encodes match distance minus one
                                    matchDistance = offset + 1;
                                    literalLength = command & 0x3;
                                }
                            } else if (firstCommand && command > 0x11) {
                                // A block starting with a byte >= 18 begins with command - 17 literals.
                                literalLength = command - 17;
                            } else if ((command & 0xf0) == 0x10) {
                                // 0b0001_HMMM (0bMMMM_MMMM)* 0bPPPP_PPPP_PPPP_PPLL
                                // length 3.. ; distance [16384..49151], H selects the upper half
                                matchLength = command & 0x7;
                                if (matchLength == 0) {
                                    matchLength = readVariableLength(input, inputAddress, inputLimit, 0x7);
                                }
                                matchLength += 2;

                                if (inputLimit - input < SIZE_OF_SHORT) {
                                    throw MalformedInputException(input - inputAddress);
                                }
                                const uint32_t trailer = loadLittleEndian16(input);
                                input += SIZE_OF_SHORT;

                                matchDistance = (trailer >> 2) + ((command & 0x8) == 0 ? 0x4000u : 0x8000u);
                                literalLength = trailer & 0x3;
                            } else if ((command & 0xe0) == 0x20) {
                                // 0b001M_MMMM (0bMMMM_MMMM)* 0bPPPP_PPPP_PPPP_PPLL
                                // length 3.. ; distance [1..16384]
                                matchLength = command & 0x1f;
                                if (matchLength == 0) {
                                    matchLength = readVariableLength(input, inputAddress, inputLimit, 0x1f);
                                }
                                matchLength += 2;

                                if (inputLimit - input < SIZE_OF_SHORT) {
                                    throw MalformedInputException(input - inputAddress);
                                }
                                const uint32_t trailer = loadLittleEndian16(input);
                                input += SIZE_OF_SHORT;

                                matchDistance = (trailer >> 2) + 1;
                                literalLength = trailer & 0x3;
                            } else if ((command & 0xc0) != 0) {
                                // 0bMMMP_PPLL 0bPPPP_PPPP: length [3..8], distance [1..2048]
                                if (input >= inputLimit) {
                                    throw MalformedInputException(input - inputAddress);
                                }
                                matchLength = ((command & 0xe0) >> 5) + 1;
                                uint64_t offset = (command & 0x1c) >> 2;
                                offset |= static_cast<uint64_t>(static_cast<uint8_t>(*input++)) << 3;
                                matchDistance = offset + 1;
                                literalLength = command & 0x3;
                            } else {
                                // Defensive: the branches above already cover every byte value.
                                throw MalformedInputException(
                                    input - inputAddress - 1, "Invalid LZO command " + toHex(command));
                            }
                            firstCommand = false;

                            if (matchLength != 0) {
                                output = copyMatch(
                                    output, outputAddress, outputLimit, matchDistance, matchLength,
                                    input - inputAddress);
                            }
                            copyLiteral(input, inputAddress, inputLimit, output, outputLimit, literalLength);
                            lastLiteralLength = literalLength;
                        }

                        // The end-of-block command 0x11 must be followed by two zero bytes.
                        if (inputLimit - input < SIZE_OF_SHORT || loadLittleEndian16(input) != 0) {
                            throw MalformedInputException(input - inputAddress);
                        }
                        input += SIZE_OF_SHORT;
                    }

                    return static_cast<uint64_t>(output - outputAddress);
                }
            } // namespace compression
        } // namespace common
    } // namespace pollux
} // namespace kumo
